{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Finetune LLMs with GRPO\n",
        "\n",
        "This notebook shows how to finetune an LLM with GRPO, using the `trl` library.\n",
        "\n",
        "It's by [Ben Burtenshaw](https://huggingface.co/burtenshaw) and [Maxime Labonne](https://huggingface.co/mlabonne).\n",
        "\n",
        "This is a minimal example. For a complete example, refer to the GRPO chapter in the [course](https://huggingface.co/course/en/chapter12/1).\n",
        "\n",
        "## Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l3IstgzN63QW"
      },
      "outputs": [],
      "source": [
        "!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off\n",
        "!pip install -qqq flash-attn --no-build-isolation --progress-bar off"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "5Y-X13wB7UP4"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import wandb\n",
        "from datasets import load_dataset\n",
        "from peft import LoraConfig, get_peft_model\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "from trl import GRPOConfig, GRPOTrainer\n",
        "\n",
        "# Log to Weights & Biases\n",
        "wandb.login()\n",
        "\n",
        "# Load dataset\n",
        "dataset = load_dataset(\"mlabonne/smoltldr\")\n",
        "print(dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3tLRvi5i-Qls"
      },
      "outputs": [],
      "source": [
        "# Load model\n",
        "model_id = \"HuggingFaceTB/SmolLM-135M-Instruct\"\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    torch_dtype=\"auto\",\n",
        "    device_map=\"auto\",\n",
        "    attn_implementation=\"flash_attention_2\",\n",
        ")\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "# Load LoRA\n",
        "lora_config = LoraConfig(\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    r=16,\n",
        "    lora_alpha=32,\n",
        "    target_modules=\"all-linear\",\n",
        ")\n",
        "model = get_peft_model(model, lora_config)\n",
        "print(model.print_trainable_parameters())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define Reward Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "745L0RC6-XBT"
      },
      "outputs": [],
      "source": [
        "# Reward function\n",
        "def reward_len(completions, **kwargs):\n",
        "    return [-abs(50 - len(completion)) for completion in completions]"
      ]
    },
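    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reward above can be checked by hand before training. The cell below is a minimal sketch that assumes completions arrive as plain strings (in a conversational dataset they would be lists of messages, and `len` would count messages instead). The function is repeated so the cell runs on its own, and the `samples` strings are hypothetical, not taken from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Same reward as above, repeated so this cell is self-contained\n",
        "def reward_len(completions, **kwargs):\n",
        "    return [-abs(50 - len(completion)) for completion in completions]\n",
        "\n",
        "# Hypothetical completions of 10, 50, and 120 characters\n",
        "samples = [\"x\" * 10, \"x\" * 50, \"x\" * 120]\n",
        "rewards = reward_len(samples)\n",
        "print(rewards)  # [-40, 0, -70]"
      ]
    },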
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define Training Arguments"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Training arguments\n",
        "training_args = GRPOConfig(\n",
        "    output_dir=\"GRPO\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=8,\n",
        "    gradient_accumulation_steps=2,\n",
        "    max_prompt_length=512,\n",
        "    max_completion_length=96,\n",
        "    num_generations=8,\n",
        "    optim=\"adamw_8bit\",\n",
        "    num_train_epochs=1,\n",
        "    bf16=True,\n",
        "    report_to=[\"wandb\"],\n",
        "    remove_unused_columns=False,\n",
        "    logging_steps=1,\n",
        ")\n",
        "\n",
        "# Trainer\n",
        "trainer = GRPOTrainer(\n",
        "    model=model,\n",
        "    reward_funcs=[reward_len],\n",
        "    args=training_args,\n",
        "    train_dataset=dataset[\"train\"],\n",
        ")\n",
        "\n",
        "# Train model\n",
        "wandb.init(project=\"GRPO\")\n",
        "trainer.train()"
      ]
    },
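    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `num_generations=8`, each prompt gets a group of eight sampled completions, and GRPO scores each completion relative to the others in its group rather than with a separate value model. The cell below is a rough sketch of that normalization in plain Python. The `group_advantages` helper, the `eps` value, and the sample rewards are illustrative assumptions, and `trl` may differ in the details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import statistics\n",
        "\n",
        "\n",
        "def group_advantages(rewards, eps=1e-4):\n",
        "    # Hypothetical helper: center and scale the rewards of one prompt's group\n",
        "    mean = statistics.mean(rewards)\n",
        "    # Sample standard deviation; a single-element group has no spread\n",
        "    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0\n",
        "    return [(r - mean) / (std + eps) for r in rewards]\n",
        "\n",
        "\n",
        "# Hypothetical length rewards for four completions of the same prompt\n",
        "group_rewards = [-40, 0, -70, -5]\n",
        "advantages = group_advantages(group_rewards)\n",
        "\n",
        "# Completions above the group mean get positive advantages, the rest negative\n",
        "print([round(a, 2) for a in advantages])"
      ]
    },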
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Push Model to Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oKHhpA4z-sRF"
      },
      "outputs": [],
      "source": [
        "# Save model\n",
        "merged_model = trainer.model.merge_and_unload()\n",
        "merged_model.push_to_hub(\"<your-model-id>\", private=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generate Text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = \"\"\"\n",
        "# A long document about the Cat\n",
        "\n",
        "The cat (Felis catus), also referred to as the domestic cat or house cat, is a small \n",
        "domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.\n",
        "Advances in archaeology and genetics have shown that the domestication of the cat occurred\n",
        "in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges\n",
        "freely as a feral cat avoiding human contact. It is valued by humans for companionship and\n",
        "its ability to kill vermin. Its retractable claws are adapted to killing small prey species\n",
        "such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,\n",
        "and its night vision and sense of smell are well developed. It is a social species,\n",
        "but a solitary hunter and a crepuscular predator. Cat communication includes\n",
        "vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as\n",
        "well as body language. It can hear sounds too faint or too high in frequency for human ears,\n",
        "such as those made by small mammals. It secretes and perceives pheromones.\n",
        "\"\"\"\n",
        "\n",
        "messages = [\n",
        "    {\"role\": \"user\", \"content\": prompt},\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6jbz8DYd-o7A"
      },
      "outputs": [],
      "source": [
        "# Generate text\n",
        "from transformers import pipeline\n",
        "\n",
        "generator = pipeline(\"text-generation\", model=\"<your-model-id>\")\n",
        "\n",
        "## Or use the model and tokenizer we defined earlier\n",
        "# generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
        "\n",
        "generate_kwargs = {\n",
        "    \"max_new_tokens\": 256,\n",
        "    \"do_sample\": True,\n",
        "    \"temperature\": 0.5,\n",
        "    \"min_p\": 0.1,\n",
        "}\n",
        "\n",
        "generated_text = generator(messages, generate_kwargs=generate_kwargs)\n",
        "\n",
        "print(generated_text)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
